-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement an algsimp optimization for dot operation. #28170
Conversation
Huh, I am surprised a simple One is the question about the protobuf int64 class. We have an AsInt64Slice helper for exactly this. |
About the RepeatedField and int64 type comment, in this push I copied lhs_contracting_dims and rhs_contracting_dims out to a std::vector at the beginning and manipulate the vector since then, as we do not actually modify the dnums of the dot anyway. Besides a few small typos here are the outstanding comments I can remember:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few minor things, otherwise this looks great!
reshape->operand(0)->shape(), reshape->shape()); | ||
CHECK_EQ(lhs_contracting_dims.size(), 1); | ||
if ((unmodified_dims.size() != reshape->shape().rank() - 1) || | ||
(std::find_if(unmodified_dims.begin(), unmodified_dims.end(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
absl::c_find_if
, but better would be absl::c_any_of
.
|
||
// Check if reshape squishes some dims into one dim, and that this one | ||
// dim is the dot's lhs contracting dim. | ||
// The size of unmodified_dims should be N - 1, where N is the rank of the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, if this line is a new paragraph, put a blank line before it. If it is not a new paragraph, flow it up with the previous line.
return nullptr; | ||
} | ||
|
||
// Check if reshape squishes some dims into one dim, and that this one |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, "Check if...and that" has bad parallelism, best fix is probably "Check that...and that".
|
||
// Require single contracting dim to make the implementation easier to | ||
// track contracting dims. | ||
if (dnums.lhs_contracting_dimensions_size() != 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we pull the vectors lhs_contracting_dims
and rhs_contracting_dims
below this if statement, then we can simply do
// Comment explaining why we're pulling these into vectors, I am still not sure what is the problem this solves, it seems to be more complex to have two copies of one piece of data?
std::vector<int64> lhs_contracting_dims = {dnums.lhs_contracting_dims[0]};
The basic idea is that dot(reshape(transpose(A)), constant) can be replaced by dot(reshape(A), reshape(transpose(reshape(constant)))) if the effect of lhs transpose and reshape is to reorder elements in contracting dims. We inverse the reorder on the constant side so that it can be constant folded.
Thanks for approving the PR. Just while you are reviewing it I made some changes to not copy out the lhs_contracting_dims before hand. If this looks better or equally well I can leave it this way. I will fix things you suggested as well. |
Looks even better to me! |
@jlebar My previous push seemed to overwrite your approval. I just update the PR to fix review comments. Basically this revision just changes absl::c_find_if to absl::c_any_of and fixes comment format. If this revision looks good to you, could you approve it again? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
\o/ @gbaned would you be willing to merge this?
@jlebar sure. I'm taking care of this PR and helping to get it merged. Thanks you! |
PiperOrigin-RevId: 245751800
PiperOrigin-RevId: 245765089
I had to roll this back due to a test failure; one of the CHECKs added here was failing. Overall this is kind of a good thing, it means that a real model used in production is affected by this change. :) I will see if there's an easy fix that I can make, and if not I'll give you a testcase. |
Thanks! Let me know if there is anything I can do on my side. |
@BinFan would you be willing to check the following patch for me? The first testcase in here is the one that was crashing for me.
|
@jlebar Thanks a lot for the patch! It looks good. And indeed I missed the size 1 dim case. I'm wondering if we should add check after filling in unmodified_transpose_dims something like
because I'm thinking of this example
I think this example would pass all the check: After pulling in reshape, lhs_contracting_dims={0,2,3}, and the transpose only permute dimensions 0 and 2. But the relative order of dim 2 and 3 does not change either, so should be no opportunity here. |
This one does not trigger the transformation. I didn't step through in a debugger, but I think it's because the I added this one as a testcase. |
Looks good to me. Thanks! |
No functional change. Relevant to PR #28170. PiperOrigin-RevId: 246051968
The basic idea is that dot(reshape(transpose(A)), constant) can be replaced by dot(reshape(A), reshape(transpose(reshape(constant)))) if the effect of lhs transpose and reshape is to reorder elements in lhs contracting dims. We apply inverse reordering on the constant side, and then the inverse reordering can be constant folded.